import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiScaleBSQ(nn.Module):
    """Residual multi-scale wrapper around one shared ``BSQ`` quantizer.

    For every level of ``scale_schedule`` the current residual is resampled
    along the time axis to that level's length, quantized, resampled back to
    the input length and subtracted from the residual. The per-level quantized
    features are summed to form the output.
    """

    def __init__(self, codebook_dim=32, scale_schedule=None):
        """Set up the shared quantizer and remember the scale schedule.

        Args:
            codebook_dim: Channel width handed to ``BSQ``; the implied
                codebook size is ``2 ** codebook_dim``.
            scale_schedule: Sized sequence with the time-axis length of each
                level, in processing order. Required despite the ``None``
                default in the signature.

        Raises:
            TypeError: If ``scale_schedule`` is not provided.
            ValueError: If ``scale_schedule`` has no levels.
        """
        super().__init__()
        # Fail here with a message naming the argument instead of the opaque
        # "object of type 'NoneType' has no len()" further down.
        if scale_schedule is None:
            raise TypeError(
                "MultiScaleBSQ requires `scale_schedule`: a sequence with the "
                "time-axis length of every quantization level"
            )
        if len(scale_schedule) == 0:
            # forward() stacks one loss per level and vqidx_to_feat() reads
            # scale_schedule[-1]; neither works without at least one level.
            raise ValueError("`scale_schedule` must contain at least one level")
        # codebook size -> 2 ** codebook_dim
        self.codebook_dim = codebook_dim
        self.scale_lvls = len(scale_schedule)
        self.scale_schedule = scale_schedule
        self.bsq_quant = BSQ(codebook_dim=codebook_dim)

    def forward(self, f_BTC):
        """Quantize ``f_BTC`` level by level on the running residual.

        Args:
            f_BTC: 3-D tensor laid out as (batch, time, channel).

        Returns:
            A tuple ``(quantized_out, all_bit_indices, all_losses)``:
            the sum of every level's quantized features at the input length,
            the per-level bit tensors concatenated along dim 1, and the
            per-level auxiliary losses stacked along the last dim.
        """
        _, T, _ = f_BTC.size()  # unpacking also enforces a 3-D input
        quantized_out, residual = 0.0, f_BTC
        all_losses, all_bit_indices = [], []
        for pt in self.scale_schedule:
            needs_resample = pt != T
            if needs_resample:
                # F.interpolate works on (B, C, T), hence the permutes.
                level_input = (
                    F.interpolate(residual.permute(0, 2, 1), size=pt, mode="area")
                    .permute(0, 2, 1)
                    .contiguous()
                )
            else:
                level_input = residual
            quantized, bit_indices, loss = self.bsq_quant(level_input)
            if needs_resample:
                # Bring the level's output back to the input length.
                quantized = (
                    F.interpolate(quantized.permute(0, 2, 1), size=T, mode="linear")
                    .permute(0, 2, 1)
                    .contiguous()
                )
            # The residual update is detached, so gradients do not flow
            # through it (remove_residual_detach = False).
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized
            all_bit_indices.append(bit_indices)
            all_losses.append(loss)
        # stack all losses and indices
        all_losses = torch.stack(all_losses, dim=-1)
        all_bit_indices = torch.cat(all_bit_indices, dim=1)
        return quantized_out, all_bit_indices, all_losses

    @torch.no_grad()
    def vqidx_to_feat(self, bit_indices, multi_scale=False):
        """Rebuild features from concatenated per-level bits.

        Args:
            bit_indices: Tensor of 0/1 values laid out as (batch, tokens,
                codebook_dim), with the levels laid out back to back along
                dim 1 in ``scale_schedule`` order (the layout ``forward``
                produces).
            multi_scale: Selects which of the two outputs below is returned.

        Returns:
            With ``multi_scale=False``: the sum of all levels, each upsampled
            to ``scale_schedule[-1]``, shaped (batch, scale_schedule[-1],
            codebook_dim).
            With ``multi_scale=True``: for every level but the last, the
            running sum of the levels seen so far, area-resampled to the
            *next* level's length; these are concatenated along the token
            axis, giving (batch, sum(scale_schedule[1:]), codebook_dim).
        """
        B, T, C = bit_indices.shape[0], self.scale_schedule[-1], self.codebook_dim
        # Map bits {0, 1} to {-1, +1} / sqrt(codebook_dim).
        ori_h_BTC = (bit_indices.float() * 2 - 1.0) / (self.codebook_dim**0.5)
        # [pn_start, pn_next) is the token range of the level being processed.
        pn_start, pn_next = 0, self.scale_schedule[0]
        if multi_scale:
            ori_h_BCT = ori_h_BTC.permute(0, 2, 1).contiguous()
            f_hat = bit_indices.new_zeros(B, C, T, dtype=torch.float32)
            next_scales = []
            for pidx in range(self.scale_lvls - 1):
                h_BCT = F.interpolate(ori_h_BCT[..., pn_start:pn_next], size=T, mode="linear")
                f_hat.add_(h_BCT)
                pn_start = pn_next
                pn_next = pn_next + self.scale_schedule[pidx + 1]
                # Snapshot of the running sum at the next level's length.
                next_scales.append(F.interpolate(f_hat, size=self.scale_schedule[pidx + 1], mode="area"))
            return torch.cat(next_scales, dim=-1).permute(0, 2, 1).contiguous()
        else:
            f_hat = bit_indices.new_zeros(B, T, C, dtype=torch.float32)
            for pidx in range(self.scale_lvls - 1):
                h_BCT = F.interpolate(
                    ori_h_BTC[:, pn_start:pn_next].permute(0, 2, 1).contiguous(), size=T, mode="linear"
                )
                f_hat.add_(h_BCT.permute(0, 2, 1).contiguous())
                pn_start = pn_next
                pn_next = pn_next + self.scale_schedule[pidx + 1]
            # The remaining tokens are the last level and are added without
            # resampling (T is defined as the last level's length).
            f_hat.add_(ori_h_BTC[:, pn_start:])
            return f_hat


class BSQ(nn.Module):
    """Binary spherical quantizer.

    The input is L2-normalized along its last dim and every coordinate is
    snapped to ``+1/sqrt(codebook_dim)`` or ``-1/sqrt(codebook_dim)`` according
    to its sign, so the quantized vector has unit norm. The module has no
    learnable parameters.
    """

    def __init__(self, codebook_dim=32, inv_temperature=100.0, commit_loss_weight=0.2, entropy_loss_weight=0.1):
        """
        Args:
            codebook_dim: Required size of the input's last dim.
            inv_temperature: Multiplier inside the sigmoid of the soft
                entropy estimate; the entropy penalty is divided by it.
            commit_loss_weight: Weight of the commitment term in the
                auxiliary loss.
            entropy_loss_weight: Weight of the entropy penalty in the
                auxiliary loss.
        """
        super().__init__()
        self.inv_temperature = inv_temperature
        self.commit_loss_weight = commit_loss_weight
        self.entropy_loss_weight = entropy_loss_weight
        self.codebook_dim = codebook_dim

    def forward(self, f_BTC):
        """Quantize ``f_BTC`` and compute the auxiliary loss.

        Args:
            f_BTC: 3-D tensor laid out as (batch, time, codebook_dim).

        Returns:
            A tuple ``(quantized, bit_indices, aux_loss)``: the quantized
            features (straight-through, same shape as the input), an int
            tensor that is 1 where the quantized value is positive and 0
            elsewhere, and the scalar auxiliary loss.
        """
        f_BTC = F.normalize(f_BTC, dim=-1)
        # Forward value is the quantized code; gradients pass straight through to f_BTC.
        quantized = self.quantize(f_BTC)  # B, T, C
        # Entropy term: minimizing it lowers the per-sample entropy and raises
        # the entropy of the batch-averaged bit distribution.
        persample_entropy, cb_entropy = self.soft_entropy_loss(f_BTC)
        entropy_penalty = (persample_entropy - cb_entropy) / self.inv_temperature
        # Commitment term: squared distance between input and its (detached) code.
        commit_loss = torch.mean(((quantized.detach() - f_BTC) ** 2).sum(dim=-1))
        aux_loss = entropy_penalty * self.entropy_loss_weight + commit_loss * self.commit_loss_weight
        # One bit per channel: the sign of the quantized value.
        bit_indices = (quantized > 0).int()  # B, T, C
        return quantized, bit_indices, aux_loss

    def quantize(self, z):
        """Snap every coordinate of ``z`` to ``+/- 1/sqrt(codebook_dim)``.

        Coordinates that are exactly zero (or NaN) map to the negative value.
        The result carries straight-through gradients with respect to ``z``.

        Raises:
            AssertionError: If the last dim of ``z`` is not ``codebook_dim``.
        """
        assert z.shape[-1] == self.codebook_dim, f"Expected {self.codebook_dim} dimensions, got {z.shape[-1]}"
        q_scale = 1.0 / (self.codebook_dim**0.5)
        # Derive the +/-1 pattern from the comparison mask so it is created on
        # z's own device and dtype, rather than building CPU scalar tensors
        # and converting them on every call.
        signs = (z > 0).to(z.dtype) * 2 - 1
        zhat = q_scale * signs  # on unit sphere
        return z + (zhat - z).detach()

    @staticmethod
    def _entropy(prob, dim=-1):
        """Return ``-sum(p * log(p + 1e-8))`` over ``dim``."""
        return -(prob * torch.log(prob + 1e-8)).sum(dim=dim)

    def soft_entropy_loss(self, z):
        """Estimate entropies from a sigmoid relaxation of each bit.

        Args:
            z: 3-D tensor laid out as (batch, length, codebook_dim).

        Returns:
            A tuple of scalars ``(per_sample_entropy, codebook_entropy)``:
            the per-position bit entropies summed over channels and averaged
            over batch and length, and the entropy of the bit probabilities
            averaged over batch and length, summed over channels.
        """
        p = torch.sigmoid(-4 * z / (self.codebook_dim**0.5) * self.inv_temperature)
        prob = torch.stack([p, 1 - p], dim=-1)  # (b, l, codebook_dim, 2)
        # (b, l, codebook_dim) -> (b, l) -> scalar
        per_sample_entropy = self._entropy(prob, dim=-1).sum(dim=-1).mean()
        # macro average of the probability of each subgroup
        avg_prob = prob.mean(dim=[0, 1])  # (codebook_dim, 2)
        codebook_entropy = self._entropy(avg_prob, dim=-1)
        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum()
